# === load relevant libraries
library(data.table)
library(ggplot2)
library(DescTools)
library(plm)
library(lfe)
library(dplyr)
library(reshape)

# === First, compute cumulative flow and return paths for top 1 - bottom 1, and top 3 - bottom 3 paths
# NOTE(review): rm(list = ls()) wipes the caller's workspace when this script is
# source()d into a live session. It is kept so the script's side effects are
# unchanged; consider running the script in a fresh R process instead.
rm(list = ls())

# style-level data: keep only the columns used below
data <- readRDS('input_data/morningstar_style_data.RDS')
data <- data[, list(yyyymm, category, ret, flow, expSum)]

# in each month, sort styles by exponential sum of ratings
# (rows with a missing month, category or expSum cannot be ranked and are dropped)
sortData <- data[!is.na(expSum), list(yyyymm, category, expSum)]
sortData <- sortData[complete.cases(sortData)]
sortData <- sortData[order(yyyymm, expSum)]

# The bin labels (1 = lowest expSum, 9 = highest) and the top/bottom
# comparisons further down rely on exactly nine ranked styles per month.
# Fail with an explicit message instead of an opaque length-mismatch error.
bad_months <- sortData[, .N, by = yyyymm][N != 9L, yyyymm]
if (length(bad_months) > 0) {
  stop('expected exactly 9 ranked styles per month; violated in: ',
       paste(bad_months, collapse = ', '), call. = FALSE)
}
rm(bad_months)

# rank within month in a single grouped pass (rows are already ordered by expSum)
sortData[, bin := seq_len(.N), by = yyyymm]

# create flow and return paths for subsequent 36 months
# reshape to long format: one row per (yyyymm, category, variable), where the
# variable is either 'ret' or 'flow'
data[, expSum := NULL]
data <- as.data.table(melt(data, id.vars = c('yyyymm', 'category')))
names(data)[3] <- 'var'

# Number the months 1..T in chronological order. The index is sorted
# explicitly so that it does not depend on the row order of the input file;
# sort() also drops a missing yyyymm, so such rows fall out at the merge.
# NOTE(review): shifting idx by h equals a shift of h calendar months only if
# the sample has no missing months -- confirm against the input data.
month_index <- data.table(yyyymm = sort(unique(data$yyyymm)))
month_index[, idx := seq_len(.N)]
data <- merge(data, month_index, by = 'yyyymm')

# For each horizon h = 1..36, re-index every observation to the month h steps
# earlier, so that after the merges below each formation month carries the
# values realised h months later. Built in one pass rather than by growing a
# table inside a loop.
out <- rbindlist(lapply(1:36, function(h) {
  data[, list(idx = idx - h, hor = h, category, var, value)]
}))

# attach the formation month (inner join: shifted indices that fall before the
# first sample month are dropped), then that month's bin for the style
out <- merge(out, month_index, by = 'idx')
out <- merge(out, sortData[, list(yyyymm, category, bin)],
             by = c('yyyymm', 'category'), all.x = TRUE)

# drop rows with any missing field (unranked style-months, missing ret/flow)
out <- out[complete.cases(out)]
out[, idx := NULL]
data <- copy(out); rm(out, sortData, month_index)

# summarize by regime: before vs after 2002
# (the split is yyyymm > 200206; everything else is labelled "before")
data[, period := ifelse(yyyymm > 200206, '2_after 2002', '1_before 2002')]

# average value per bin, horizon, regime and variable
data <- data[, list(value = mean(value)), by = list(bin, hor, period, var)]
data <- data[order(hor)]

# compute differences between top 1 (3) vs bottom 1 (3) styles
# one table per leg of the comparison; the 3-bin legs are the simple mean of
# the three bin averages
grp_cols <- c('hor', 'period', 'var')
legs <- list(
  data[bin == 9, list(top1 = mean(value)), by = grp_cols],
  data[bin == 1, list(bottom1 = mean(value)), by = grp_cols],
  data[bin %in% 7:9, list(top3 = mean(value)), by = grp_cols],
  data[bin %in% 1:3, list(bottom3 = mean(value)), by = grp_cols]
)

# inner-join the four legs on (hor, period, var), then form the spreads,
# multiplied by 100
spreads <- Reduce(function(lhs, rhs) merge(lhs, rhs, by = grp_cols), legs)
spreads <- spreads[order(hor, period, var)]
spreads <- spreads[, list(hor, period, var,
                          top_minus_bottom = 100 * (top1 - bottom1),
                          top3_minus_bottom3 = 100 * (top3 - bottom3))]
data <- copy(spreads); rm(spreads, legs, grp_cols)

# summarize by period and output point estimates
# map each monthly horizon 1..36 to its reporting bucket (6, 6, 12 and 12 months)
horizon_map <- data.table(
  hor = 1:36,
  horizon = rep(c('1_months_1-6', '2_months_7-12', '3_months_13-24', '4_months_25-36'),
                times = c(6, 6, 12, 12))
)
data <- merge(data, horizon_map, by = 'hor'); rm(horizon_map)

# Each table below averages the monthly spread within a (period, horizon) cell,
# rounds to two decimals and prints with periods as rows, horizons as columns.

# output: Panel A, flows (point estimates)
cast(
  data[var == 'flow', list(V1 = round(mean(top_minus_bottom), 2)), by = list(period, horizon)],
  period ~ horizon
)

# output: Panel A, returns (point estimates)
cast(
  data[var == 'ret', list(V1 = round(mean(top_minus_bottom), 2)), by = list(period, horizon)],
  period ~ horizon
)

# output: Panel B, flows (point estimates)
cast(
  data[var == 'flow', list(V1 = round(mean(top3_minus_bottom3), 2)), by = list(period, horizon)],
  period ~ horizon
)

# output: Panel B, returns (point estimates)
cast(
  data[var == 'ret', list(V1 = round(mean(top3_minus_bottom3), 2)), by = list(period, horizon)],
  period ~ horizon
)


# Summarize bootstrapped standard errors, computed by permuting styles each year
data <- readRDS('input_data/style_price_pressure_bootstraps.RDS')

# same horizon buckets as for the point estimates (6, 6, 12 and 12 months)
horizon_map <- data.table(
  hor = 1:36,
  horizon = rep(c('1_months_1-6', '2_months_7-12', '3_months_13-24', '4_months_25-36'),
                times = c(6, 6, 12, 12))
)
data <- merge(data, horizon_map, by = 'hor'); rm(horizon_map)

# standard deviation of each spread over all rows that fall in a
# (period, var, horizon) cell
data <- data[, list(top_minus_bottom = sd(top_minus_bottom),
                    top3_minus_bottom3 = sd(top3_minus_bottom3)),
             by = list(period, var, horizon)]


# After the step above there is one row per cell, so the mean() in the tables
# below simply passes that value through before rounding.

# output: Panel A, flows (standard error)
cast(
  data[var == 'flow', list(V1 = round(mean(top_minus_bottom), 2)), by = list(period, horizon)],
  period ~ horizon
)

# output: Panel A, returns (standard error)
cast(
  data[var == 'ret', list(V1 = round(mean(top_minus_bottom), 2)), by = list(period, horizon)],
  period ~ horizon
)

# output: Panel B, flows (standard error)
cast(
  data[var == 'flow', list(V1 = round(mean(top3_minus_bottom3), 2)), by = list(period, horizon)],
  period ~ horizon
)

# output: Panel B, returns (standard error)
cast(
  data[var == 'ret', list(V1 = round(mean(top3_minus_bottom3), 2)), by = list(period, horizon)],
  period ~ horizon
)
